import math
import torch
import torch.nn as nn
import torch.nn.functional as F
Input Data
| Source | Target |
| I am from China | 我来自中国 |
| You and me are best friends | 你我是最好的朋友 |
batch_size=2, num_steps=8, "<unk>"=0,
"<pad>"=1, "<bos>"=2, "<eos>"=3.
‘I am from China’ -> \([4, 5, 6, 7, 1, 1, 1, 1]\)
‘You and me are best friends’ -> \([8, 9, 10, 11, 12, 13, 1, 1]\)
\[X:\begin{bmatrix}4&5&6&7&1&1&1&1\\8&9&10&11&12&13&1&1\end{bmatrix}\ \ X\_valid\_len:\begin{bmatrix}4&6\end{bmatrix}\]
‘我来自中国’ -> \([4, 5, 6, 7, 8, 1, 1, 1]\)
‘你我是最好的朋友’ -> \([9, 4, 11, 12, 13, 14, 15, 16]\)
\[Y:\begin{bmatrix}4&5&6&7&8&1&1&1\\9&4&11&12&13&14&15&16\end{bmatrix}\ \ Y\_valid\_len:\begin{bmatrix}5&8\end{bmatrix}\]
Basic Functions
sequence_mask
def sequence_mask(X, valid_len, value=0.0):
    """Overwrite every position at or beyond each row's valid length with `value`.

    :param X: tensor of shape (batch_size, seq_len, ...); mutated in place
    :param valid_len: (batch_size,) number of valid leading positions per row
        (a flattened (batch_size*query_lens,) vector also works, as used by
        masked_softmax)
    :param value: fill value for the masked-out positions
    :return: X (the same tensor, modified in place)
    """
    seq_len = X.size(1)
    positions = torch.arange(seq_len, dtype=torch.float32, device=X.device)
    # (1, seq_len) >= (batch_size, 1) broadcasts to a (batch_size, seq_len)
    # boolean matrix marking the positions to blank out.
    invalid = positions.unsqueeze(0) >= valid_len.unsqueeze(1)
    X[invalid] = value
    return X
# Demo: mask a batch of ones; rows at index >= valid_len become -99.
X = torch.ones(2, 6, 8) # shape: (batch_size, seq_len, input_dim)
valid_len = torch.tensor([4, 6]).reshape(2, ) # shape: (batch_size, )
print(sequence_mask(X, valid_len, -99))
## tensor([[[ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [-99., -99., -99., -99., -99., -99., -99., -99.],
## [-99., -99., -99., -99., -99., -99., -99., -99.]],
##
## [[ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.]]])
(1, seq_len) < (batch_size, 1) \[ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1\end{bmatrix}\\\begin{bmatrix}2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\\\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1&1&1\end{bmatrix}\\\begin{bmatrix}2&2&2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ mask:\begin{bmatrix}\begin{bmatrix}True&False&False\end{bmatrix}\\\begin{bmatrix}True&True&False\end{bmatrix}\end{bmatrix} \]
masked_softmax
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of X, ignoring key positions past valid_lens.

    :param X: attention scores, shape (batch_size, query_lens, key_lens)
    :param valid_lens: None, (batch_size,) or (batch_size, query_lens)
    :return: attention weights of the same shape as X; masked positions get
        (numerically) zero weight because their score is pushed to -1e6.
    """
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per sequence: every query row of that sequence shares it.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # Already one length per (sequence, query) pair: just flatten.
        valid_lens = valid_lens.reshape(-1)
    flat = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return F.softmax(flat.reshape(shape), dim=-1)
# Demo: random scores (2, 6, 8); key positions past 4 resp. 6 get weight 0.
masked_softmax(torch.rand(2, 6, 8), torch.tensor([4, 6]))
## tensor([[[0.2318, 0.2245, 0.2486, 0.2951, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.3212, 0.2489, 0.2753, 0.1546, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2987, 0.1641, 0.1543, 0.3829, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2037, 0.1553, 0.2646, 0.3764, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2201, 0.2857, 0.2216, 0.2725, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2362, 0.2553, 0.3232, 0.1853, 0.0000, 0.0000, 0.0000, 0.0000]],
##
## [[0.2086, 0.2447, 0.1044, 0.1487, 0.1750, 0.1186, 0.0000, 0.0000],
## [0.1459, 0.1679, 0.2042, 0.1070, 0.1879, 0.1871, 0.0000, 0.0000],
## [0.1071, 0.1046, 0.1502, 0.2344, 0.2091, 0.1947, 0.0000, 0.0000],
## [0.1334, 0.0907, 0.0978, 0.2194, 0.2305, 0.2282, 0.0000, 0.0000],
## [0.1899, 0.1082, 0.2001, 0.1261, 0.2665, 0.1093, 0.0000, 0.0000],
## [0.2057, 0.1857, 0.1246, 0.1821, 0.1901, 0.1119, 0.0000, 0.0000]]])
DotProductAttention
class DotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """
        :param queries: (batch_size, query_lens, num_hiddens)
        :param keys: (batch_size, key_lens, num_hiddens)
        :param values: (batch_size, key_lens, num_hiddens)
        :param valid_lens: optional (batch_size,) or (batch_size, query_lens)
        """
        num_hiddens = queries.shape[-1]
        # (b, q, d) @ (b, d, k) -> (b, q, k); sqrt(d) keeps scores O(1). O(n*n*d)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(num_hiddens)
        # Kept as an attribute so callers can inspect the weights. O(n*n)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # (b, q, k) @ (b, k, d) -> (b, q, d). O(n*n*d)
        return torch.bmm(self.dropout(self.attention_weights), values)
# Demo: identical shapes for Q, K, V; valid_lens limits the attended keys.
queries, keys, values = torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14))
attention = DotProductAttention(dropout=0.5)
attention.eval()  # eval mode disables dropout for the demo
## DotProductAttention(
## (dropout): Dropout(p=0.5, inplace=False)
## )
print(attention(queries, keys, values, torch.tensor([3, 4])).shape)
## torch.Size([2, 6, 14])
MultiHeadAttention
Suppose an input matrix of dimension (1, seq_lens=6, input_size=14), and number of heads is 2. Head1 processes the red area, and head2 processes the blue area. transpose_qkv function will transpose \((1, 6, 14)\) into \((1*2, 6, 7)\) to facilitate the parallelized computation. transpose_output function will turn it into its original form.
\[\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix}\mathop{\longrightarrow}^{transpose\_qkv()}\begin{matrix}\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\end{bmatrix}\\ 
\begin{bmatrix}\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix} \end{matrix}\]
def transpose_qkv(X, num_heads):
    """Split (batch, seq, num_hiddens) into per-head batches of shape
    (batch * num_heads, seq, num_hiddens / num_heads)."""
    batch_size, seq_len = X.shape[0], X.shape[1]
    X = X.reshape(batch_size, seq_len, num_heads, -1)  # carve the hidden dim
    X = X.permute(0, 2, 1, 3)                          # heads next to batch
    return X.reshape(batch_size * num_heads, seq_len, -1)
def transpose_output(X, num_heads):
    """Inverse of transpose_qkv: (batch * num_heads, seq, d/h) back to
    (batch, seq, num_hiddens), concatenating the heads."""
    seq_len, head_dim = X.shape[1], X.shape[2]
    X = X.reshape(-1, num_heads, seq_len, head_dim)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], seq_len, -1)
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, run the heads in parallel via
    the batch dimension, then concatenate heads and project once more."""

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # After transpose_qkv each tensor is (batch * num_heads, seq, d/h).
        q = transpose_qkv(self.W_q(queries), self.num_heads)
        k = transpose_qkv(self.W_k(keys), self.num_heads)
        v = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Every head of a sequence shares that sequence's valid length.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        per_head = self.attention(q, k, v, valid_lens)
        return self.W_o(transpose_output(per_head, self.num_heads))
Positional encoding
\[\boldsymbol{Embed\_X}+\boldsymbol{P}=\begin{bmatrix}x_{11}&x_{12}&\cdots&x_{1d}\\x_{21}&x_{22}&\cdots&x_{2d}\\\vdots&\vdots&\ddots&\vdots\\x_{n1}&x_{n2}&\cdots&x_{nd}\end{bmatrix}+\begin{bmatrix}p_{11}&x_{12}&\cdots&p_{1d}\\p_{21}&p_{22}&\cdots&p_{2d}\\\vdots&\vdots&\ddots&\vdots\\p_{n1}&p_{n2}&\cdots&p_{nd}\end{bmatrix}\], where \(n\) represents seq_lens, \(d\) represents embedding size, \(\boldsymbol{P}\) is the positional encoding matrix.
\[p_{i, 2j}=\sin\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\ \ \ \ \ \ \ p_{i, 2j+1}=\cos\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\], where \(i=0,1,\cdots,n-1\) and \(j=0, 1, \cdots,d/2-1\).
Properties:
\[\begin{bmatrix}p_{i+k, 2j}\\p_{i+k, 2j+1}\end{bmatrix}=\begin{bmatrix}\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\\-\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\end{bmatrix}\begin{bmatrix}p_{i,2j}\\p_{i,2j+1}\end{bmatrix}\]
\[\begin{bmatrix}p_{i,0}&p_{i,1}&\cdots&p_{i,d-1}\end{bmatrix}\begin{bmatrix}p_{i+k,0}\\p_{i+k, 1}\\\vdots\\p_{i+k, d-1}\end{bmatrix}=\cos\Big{(}\frac{k}{10000^{0/d}}\Big{)}+\cos\Big{(}\frac{k}{10000^{2/d}}\Big{)}+\cdots+\cos\Big{(}\frac{k}{10000^{(d-2)/d}}\Big{)}\]
Assuming \(d=512\), the inner product of vectors as \(k\) increases is shown below.
import plotly.express as px
import pandas as pd
import numpy as np
def myfunc(k, d):
    """Inner product of positional encodings at distance k for dimension d:
    sum over j of cos(k / 10000^(2j/d))."""
    frequencies = np.power(10000, np.arange(0, d, 2) / d)
    return np.sum(np.cos(k / frequencies))
# Plot the inner product as the token distance k grows (d = 512).
k = np.arange(0, 200)
y = [myfunc(item, 512) for item in k]
df = pd.DataFrame({'k': k, 'y': y})
fig = px.scatter(df, x='k', y='y', width=768, height=474)
fig.show()
class PositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings to the input embeddings.

    P[0, i, 2j] = sin(i / 10000^(2j/d)); P[0, i, 2j+1] = cos(i / 10000^(2j/d)).
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute encodings for up to max_len positions: (1, max_len, d).
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        """
        :param X: (batch_size, seq_len, num_hiddens) embeddings
        :return: dropout(X + P[:, :seq_len, :])
        """
        # Out-of-place add: the original `X += ...` mutated the caller's
        # tensor in place, which is a surprising side effect for a Module.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
Add&Norm
Suppose I have an input of dimension \((batch\_size=2, seq\_lens=2, input\_size=4)\).
\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]
The normalization operator nn.LayerNorm(4) is applied on
every token.
\[ mean^{(1)}_1=\frac{1}{4}\sum_j^4x^{(1)}_{1j}\\Var^{(1)}_1=\frac{1}{4}\sum_j\big{(}x^{(1)}_{1j}-mean_1^{(1)}\big{)}^2 \]
The normalization operator nn.LayerNorm([2, 4]) is
applied on every input text.
\[ mean^{(1)}=\frac{1}{2\times4}\sum^2_i\sum^4_jx^{(1)}_{ij}\\Var^{(1)}=\frac{1}{2\times4}\sum_i^2\sum_j^4\big{(}x^{(1)}_{ij}-mean^{(1)}\big{)}^2 \]
# Compare per-token vs per-sample layer normalization.
ln1 = nn.LayerNorm(4)       # normalizes each length-4 token vector
ln2 = nn.LayerNorm([2, 4])  # normalizes each whole (2, 4) sample
with torch.no_grad():
    # X shape: (2, 2, 4)
    X = torch.tensor([
        [[1, 2, 3, 4],
         [5, 6, 7, 8]],
        [[5, 6, 7, 8],
         [5, 1, 0, -1]]
    ], dtype=torch.float32)
    print(ln1(X))
    print(ln2(X))
## tensor([[[-1.3416, -0.4472, 0.4472, 1.3416],
## [-1.3416, -0.4472, 0.4472, 1.3416]],
##
## [[-1.3416, -0.4472, 0.4472, 1.3416],
## [ 1.6465, -0.1098, -0.5488, -0.9879]]])
## tensor([[[-1.5275, -1.0911, -0.6547, -0.2182],
## [ 0.2182, 0.6547, 1.0911, 1.5275]],
##
## [[ 0.3538, 0.6683, 0.9829, 1.2974],
## [ 0.3538, -0.9042, -1.2187, -1.5332]]])
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Y is the sublayer output; X is the residual (skip) input.
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)
PositionWiseFFN
class PositionWiseFFN(nn.Module):
    """Two-layer MLP applied independently at every sequence position."""

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)
Encoder&Decoder
Encoder
class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention + position-wise FFN,
    each wrapped in a residual Add&Norm."""

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)  # self-attention
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))
class TransformerEncoder(nn.Module):
    """Embedding + positional encoding followed by a stack of EncoderBlocks."""

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            layer = EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)
            self.blks.add_module("block" + str(i), layer)

    def forward(self, X, valid_lens):
        # Scale embeddings by sqrt(d) so they are comparable in magnitude to
        # the positional encodings before the two are summed.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # One score matrix per layer, kept for inspection.
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
# Demo: vocab 200, d = 14, 2 heads, 6 layers; output keeps (batch, seq, d).
encoder = TransformerEncoder(200, 14, 14, 14, 14, [6, 14], 14, 28, 2, 6, 0.5)
#encoder.eval()
X = torch.ones((2, 6), dtype=torch.long)
valid_lens = torch.tensor([4, 6], dtype=torch.long)
print(encoder(X, valid_lens).shape)
## torch.Size([2, 6, 14])
Decoder
class DecoderBlock(nn.Module):
    # i is this block's index in the decoder stack; it selects this block's
    # slot in the shared key/value cache carried in state[2].
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        # Masked self-attention over the (growing) decoded prefix.
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Encoder-decoder attention: queries from the decoder, keys/values
        # from the encoder outputs.
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """X: (batch, num_steps, num_hiddens).
        state: [enc_outputs, enc_valid_lens, per-block key/value cache]."""
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Training decodes all steps at once, so the cache slot is empty and
        # key_values == X. At prediction time tokens arrive one step at a
        # time, so append the new step to the cached prefix.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask: row t may only attend to positions 1..t, so
            # dec_valid_lens is (batch_size, num_steps) = [[1, 2, ..., n], ...].
            dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
        else:
            # At prediction the prefix contains only already-generated
            # tokens, so no mask is needed.
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Attend over encoder outputs, ignoring padded source positions.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
class AttentionDecoder(nn.Module):
    """Base class for decoders that expose attention weights for inspection."""

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        # Declared as a property so the interface matches subclasses such as
        # TransformerDecoder, which override it with @property; the original
        # plain method made `decoder.attention_weights` mean different
        # things depending on the class.
        raise NotImplementedError
class TransformerDecoder(AttentionDecoder):
    # Mirrors TransformerEncoder, plus a final vocabulary projection and the
    # per-block key/value cache threaded through `state`.
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                                 DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        # Project hidden states to vocabulary logits.
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # state = [encoder outputs, encoder valid lengths,
        #          one key/value cache slot per decoder block].
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        # Same sqrt(d) embedding scaling as the encoder.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # [0]: self-attention weights per block; [1]: enc-dec attention weights.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
the first multi-head layer in decoder
query after being transposed by multi-heads: \((4, 8, 7)\).
key after being transposed by multi-heads: \((4, 8, 7)\rightarrow^T(4, 7, 8)\).
\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\begin{bmatrix}1&0&0&0&0&0&0&0\\s_{21}&s_{22}&0&0&0&0&0&0\\s_{31}&s_{32}&s_{33}&0&0&0&0&0\\s_{41}&s_{42}&s_{43}&s_{44}&0&0&0&0\\s_{51}&s_{52}&s_{53}&s_{54}&s_{55}&0&0&0\\s_{61}&s_{62}&s_{63}&s_{64}&s_{65}&s_{66}&0&0\\s_{71}&s_{72}&s_{73}&s_{74}&s_{75}&s_{76}&s_{77}&0\\s_{81}&s_{82}&s_{83}&s_{84}&s_{85}&s_{86}&s_{87}&s_{88}\end{bmatrix}, \forall i=2,\cdots8; \sum_j s_{ij}=1\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \]
The above score matrix indicates that the k-th token can only compute self-attention with the previous k tokens.
the second multi-head layer in decoder
query after being transposed by multi-heads:\((4, 8, 7)\).
key after being transposed by multi-heads:\((4, 6, 7)\rightarrow^T(4, 7, 6)\).
\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\vdots\\0\cdots0\\0\cdots0\end{bmatrix}_{7\times6}\rightarrow\begin{bmatrix}\cdots&0&0\\\ddots&\vdots&\vdots\\\cdots&0&0\end{bmatrix}_{8\times6}\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots \]
The above score matrix indicates that the invalid (padded) tokens in the encoder outputs are ignored.
Model
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper: encode the source, initialize the
    decoder state from the encoder outputs, then decode."""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # *args typically carries the source valid lengths.
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
# Hyper-parameters (key_size, num_hiddens, ...) and the vocabularies
# (src_vocab, tgt_vocab) are assumed to come from the data pipeline — they
# are not defined in this snippet.
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)
Train And Prediction
Masked softmax loss
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Cross-entropy over sequences where padded positions contribute zero."""

    def forward(self, pred, label, valid_len):
        """
        :pred's shape: (batch_size, num_steps, vocab_size)
        :label's shape: (batch_size, num_steps)
        :valid_len's shape: (batch_size, )
        :return: one loss value per sequence, shape (batch_size,)
        """
        # 1 on valid positions, 0 on padding.
        weights = sequence_mask(torch.ones_like(label), valid_len)
        self.reduction = 'none'  # keep the per-position losses
        # nn.CrossEntropyLoss expects the class axis second, hence the permute.
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label)
        return (unweighted_loss * weights).mean(dim=1)
Train function
def train_seq2seq(model, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a seq2seq model with teacher forcing and masked CE loss.

    :param model: EncoderDecoder-style module
    :param data_iter: yields (X, X_valid_len, Y, Y_valid_len) batches
    :param lr: Adam learning rate
    :param num_epochs: number of passes over data_iter
    :param tgt_vocab: target vocabulary (must contain '<bos>')
    :param device: torch device to train on
    """
    def xavier_init_weights(m):
        # Xavier-initialize Linear (and GRU, if any) weights.
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    model.apply(xavier_init_weights)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    model.train()
    for epoch in range(num_epochs):
        myloss = 0.0
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # Teacher forcing: decoder input is <bos> + gold target shifted
            # right by one step.
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            Y_hat, _ = model(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()
            # d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            # .item() detaches the scalar: accumulating the tensor itself
            # (as the original did) keeps every batch's autograd graph alive.
            myloss += (l.sum() / num_tokens).item()
        if (epoch + 1) % 10 == 0:
            print("loss: {:.4f}".format(myloss))
Prediction
def truncate_pad(line, num_steps, padding_token):
    """Force `line` (a list of token ids) to exactly `num_steps` items by
    truncating or right-padding with `padding_token`."""
    if len(line) >= num_steps:
        return line[:num_steps]
    return line + [padding_token] * (num_steps - len(line))
def predict_seq2seq(model, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    """Greedily decode `src_sentence` one token at a time until <eos> is
    produced or num_steps tokens have been emitted.

    Returns (translated sentence, list of per-step attention weights).
    """
    model.eval()
    # Tokenize, append <eos>, then pad/truncate to the fixed num_steps.
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    # Add the batch dimension: (1, num_steps).
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = model.encoder(enc_X, enc_valid_len)
    dec_state = model.decoder.init_state(enc_outputs, enc_valid_len)
    # Decoding starts from <bos>; each step feeds back the previous argmax.
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = model.decoder(dec_X, dec_state)
        dec_X = Y.argmax(dim=2)  # greedy choice, kept as the next decoder input
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        if save_attention_weights:
            attention_weight_seq.append(model.decoder.attention_weights)
        if pred == tgt_vocab['<eos>']:
            break  # <eos> itself is not appended to the output
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
Transformer Code
The following code is a copy of all aforementioned Transformer architecture code, convenient for copying.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def sequence_mask(X, valid_len, value=0.0):
    """Fill positions at index >= valid_len (per row) of X with `value`;
    X is modified in place and also returned."""
    seq_positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    # Broadcast (1, seq) against (batch, 1) to flag out-of-range positions.
    out_of_range = seq_positions.unsqueeze(0) >= valid_len.unsqueeze(1)
    X[out_of_range] = value
    return X
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of X, ignoring key positions past
    valid_lens (None disables masking)."""
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per sequence: repeat it for every query row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per (sequence, query) pair: flatten to match rows.
        valid_lens = valid_lens.reshape(-1)
    flat = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return F.softmax(flat.reshape(shape), dim=-1)
class DotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        num_hiddens = queries.shape[-1]
        # (b, q, d) @ (b, d, k) -> (b, q, k), scaled to keep scores O(1).
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(num_hiddens)
        # Stored so callers can inspect the attention weights afterwards.
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
def transpose_qkv(X, num_heads):
    """Split (batch, seq, d) into (batch * num_heads, seq, d / num_heads)."""
    batch_size, seq_len = X.shape[0], X.shape[1]
    per_head = X.reshape(batch_size, seq_len, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(batch_size * num_heads, seq_len, -1)
def transpose_output(X, num_heads):
    """Inverse of transpose_qkv: concatenate heads back into (batch, seq, d)."""
    seq_len, head_dim = X.shape[1], X.shape[2]
    merged = X.reshape(-1, num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
    return merged.reshape(merged.shape[0], seq_len, -1)
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, run heads in parallel via the
    batch dimension, then concatenate and project."""

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Each projected tensor becomes (batch * num_heads, seq, d/h).
        q = transpose_qkv(self.W_q(queries), self.num_heads)
        k = transpose_qkv(self.W_k(keys), self.num_heads)
        v = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # All heads of one sequence share that sequence's valid length.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        per_head = self.attention(q, k, v, valid_lens)
        return self.W_o(transpose_output(per_head, self.num_heads))
class PositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings to the input embeddings."""

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Precomputed table for up to max_len positions: (1, max_len, d).
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        # Out-of-place add: the original `X += ...` mutated the caller's
        # tensor in place, which is a surprising side effect for a Module.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # X is the skip input, Y the sublayer output.
        return self.ln(self.dropout(Y) + X)
class PositionWiseFFN(nn.Module):
    """Two-layer MLP applied independently at every sequence position."""

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention + FFN, each wrapped in
    a residual Add&Norm."""

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)  # self-attention
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))
class TransformerEncoder(nn.Module):
    """Embedding + positional encoding followed by a stack of EncoderBlocks."""

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            layer = EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)
            self.blks.add_module("block" + str(i), layer)

    def forward(self, X, valid_lens):
        # sqrt(d) scaling keeps embeddings comparable to the positional table.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # One score matrix per layer, kept for inspection.
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
class DecoderBlock(nn.Module):
    # i selects this block's slot in the shared key/value cache (state[2]).
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        # Masked self-attention over the decoded prefix.
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Encoder-decoder attention (keys/values from the encoder outputs).
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """X: (batch, num_steps, num_hiddens).
        state: [enc_outputs, enc_valid_lens, per-block key/value cache]."""
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Training decodes all steps at once (cache empty, key_values == X);
        # prediction appends the new step to the cached prefix.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask: row t attends only to positions 1..t.
            dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
        else:
            # The prediction-time prefix contains only generated tokens;
            # no mask needed.
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Ignore padded source positions via enc_valid_lens.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
class AttentionDecoder(nn.Module):
    """Base class for decoders that expose attention weights for inspection."""

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        # Property form matches the overriding @property in TransformerDecoder;
        # the original plain method made attribute access inconsistent.
        raise NotImplementedError
class TransformerDecoder(AttentionDecoder):
    # Mirrors TransformerEncoder, plus a final vocabulary projection and the
    # per-block key/value cache threaded through `state`.
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                                 DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        # Projects hidden states to vocabulary logits.
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # state = [encoder outputs, encoder valid lengths,
        #          one key/value cache slot per decoder block].
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # [0]: self-attention weights per block; [1]: enc-dec attention weights.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper: encode, init decoder state, decode."""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # *args typically carries the source valid lengths.
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
# Hyper-parameters (key_size, num_hiddens, ...) and the vocabularies
# (src_vocab, tgt_vocab) are assumed to come from the data pipeline — they
# are not defined in this snippet.
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)